# Copyright 2018 Amazon.com, Inc. or its affiliates. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License"). You
# may not use this file except in compliance with the License. A copy of
# the License is located at
#
# http://aws.amazon.com/apache2.0/
#
# or in the "license" file accompanying this file. This file is
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
# ANY KIND, either express or implied. See the License for the specific
# language governing permissions and limitations under the License.

import os
import sys
import glob
import yaml
import logging
import botocore
import tempfile
import shutil
import re
from argparse import Namespace

from botocore.session import get_session

from awscli.testutils import mock, unittest, capture_output
from awscli.customizations.eks.update_kubeconfig import UpdateKubeconfigCommand
from awscli.customizations.eks.exceptions import EKSClusterError
from awscli.customizations.eks.kubeconfig import (Kubeconfig,
                                                  KubeconfigCorruptedError,
                                                  KubeconfigInaccessableError)
from tests.functional.eks.test_util import (describe_cluster_response,
                                            describe_cluster_creating_response,
                                            get_testdata,
                                            assume_role_response)

def sanitize_output(output):
    """
    Trim command output and drop everything from the first warning onward.

    A line that, ignoring surrounding whitespace, starts with "warning"
    (case-insensitively) marks the start of a "not installed" warning that
    should be ignored when comparing output. That line and every line after
    it are discarded.

    :param output: Raw captured output.
    :type output: str

    :returns: The retained lines joined by newlines, with leading and
        trailing whitespace stripped.
    :rtype: str
    """
    warning_start = re.compile('warning', re.I)
    kept = []
    for line in output.splitlines():
        if warning_start.match(line.strip()):
            break
        kept.append(line)
    # Join once instead of growing a string with += on every line.
    return '\n'.join(kept).strip()

def build_environment(entries):
    """
    Combine path strings into one environment-variable value.

    :param entries: Path strings to combine, in order.
    :type entries: list

    :returns: The entries joined with the platform path separator
        (``:`` on POSIX, ``;`` on Windows).
    :rtype: str
    """
    separator = os.pathsep
    return separator.join(entries)

class TestUpdateKubeconfig(unittest.TestCase):
    """
    Functional tests for the ``eks update-kubeconfig`` command.

    ``botocore.session.Session.create_client`` is patched for every test.
    Requests for an ``sts`` client receive ``self.sts_client_mock``; every
    other service receives ``self.client``, whose ``describe_cluster``
    returns a canned response.
    """

    def setUp(self):
        self.create_client_patch = mock.patch(
            'botocore.session.Session.create_client'
        )

        self.mock_create_client = self.create_client_patch.start()
        self.session = get_session()

        # Mock EKS client returning a canned describe_cluster response.
        self.client = mock.Mock()
        self.client.describe_cluster.return_value = describe_cluster_response()

        # Mock STS client used when --assume-role-arn is passed.
        self.sts_client_mock = mock.Mock()
        self.sts_client_mock.assume_role.return_value = assume_role_response()

        # Route create_client calls by service name. A side_effect function's
        # return value overrides return_value, so none is configured.
        # *args tolerates extra positional arguments (e.g. a region name).
        self.mock_create_client.side_effect = (
            lambda service_name, *args, **kwargs: (
                self.sts_client_mock if service_name == "sts" else self.client
            )
        )

        self.command = UpdateKubeconfigCommand(self.session)
        self.maxDiff = None

    def tearDown(self):
        self.create_client_patch.stop()

    def assert_output(self, captured, file):
        """
        Compare the captured stdout with the testdata file named ``file``.

        Stdout is passed through sanitize_output first, so trailing
        "not installed" warnings and surrounding whitespace are ignored.
        """
        with open(get_testdata(file)) as f:
            self.assertMultiLineEqual(
                sanitize_output(captured.stdout.getvalue()),
                f.read().strip()
            )

    def _get_temp_config(self, config):
        """
        Helper to access a temp config generated by initialize_tempfiles.
        """
        return os.path.join(self._temp_directory, config)

    def initialize_tempfiles(self, files):
        """
        Initializes a directory of tempfiles containing copies of each testdata
        file listed in files. The directory is removed when the test ends.
        Returns the absolute path of the containing directory.

        :param files: A list of filenames found in testdata
        :type files: list or None
        """
        self._temp_directory = tempfile.mkdtemp()
        self.addCleanup(shutil.rmtree, self._temp_directory)
        if files is not None:
            for file in files:
                shutil.copy2(get_testdata(file),
                             self._get_temp_config(file))
        return self._temp_directory

    def build_temp_environment_variable(self, configs):
        """
        Generate a string which is an environment variable
        containing the paths for each temp file corresponding to configs.
        Must be called after initialize_tempfiles.

        :param configs: The names of the configs in testdata
        to put in the environment variable
        :type configs: list
        """
        return build_environment([self._get_temp_config(config)
                                  for config in configs])

    def assert_config_state(self, config_name, correct_output_name):
        """
        Asserts that the temp config named config_name has the same content
        as the testdata named correct_output_name.
        Should be called after initialize_tempfiles.

        :param config_name: The filename (not the path) of the tempfile
        to compare
        :type config_name: str

        :param correct_output_name: The filename (not the path) of the testdata
        to compare
        :type correct_output_name: str
        """
        with open(self._get_temp_config(config_name)) as file1:
            with open(get_testdata(correct_output_name)) as file2:
                self.assertMultiLineEqual(file1.read().strip(),
                                          file2.read().strip())

    def _run_update_kubeconfig(self, args, kubeconfig_env, default_path):
        """
        Run the command with KUBECONFIG and DEFAULT_PATH patched.

        :param args: Command-line arguments for update-kubeconfig
        :type args: list

        :param kubeconfig_env: Value to place in the KUBECONFIG variable
        :type kubeconfig_env: str

        :param default_path: Value to patch in as the module's DEFAULT_PATH
        :type default_path: str
        """
        with mock.patch.dict(os.environ, {'KUBECONFIG': kubeconfig_env}):
            with mock.patch(
                    "awscli.customizations.eks.update_kubeconfig.DEFAULT_PATH",
                    default_path):
                self.command(args, None)

    def assert_cmd_dry(self, passed_config,
                       env_variable_configs,
                       default_config=os.path.join(".kube", "config")):
        """
        Run update-kubeconfig using dry-run.

        assert_cmd_dry references the testdata directory directly, since
        dry_run won't write to file. The KUBECONFIG environment variable is
        set to the testdata paths of the configs listed in
        env_variable_configs (regardless of whether they exist). The default
        path is set to the testdata path of default_config.

        :param passed_config: A filename to be passed to --kubeconfig
        :type passed_config: string or None

        :param env_variable_configs: A list of filenames to put in KUBECONFIG;
        None is treated as an empty list
        :type env_variable_configs: list or None

        :param default_config: A config to be the default path
        :type default_config: string

        :returns: The captured output
        :rtype: CapturedOutput
        """
        # No temp directory exists for dry runs, so resolve KUBECONFIG entries
        # against testdata instead of using build_temp_environment_variable
        # (which requires initialize_tempfiles to have been called).
        env_variable = build_environment(
            [get_testdata(config) for config in env_variable_configs or []]
        )
        args = ["--name", "ExampleCluster", "--dry-run"]
        if passed_config is not None:
            args += ["--kubeconfig", get_testdata(passed_config)]

        with capture_output() as captured:
            self._run_update_kubeconfig(
                args, env_variable, get_testdata(default_config)
            )

        self.mock_create_client.assert_called_once_with('eks')
        self.client.describe_cluster.assert_called_once_with(
            name='ExampleCluster'
        )

        return captured

    def assert_cmd(self, configs, passed_config,
                   env_variable_configs,
                   default_config=os.path.join(".kube", "config"),
                   verbose=False):
        """
        Run update-kubeconfig in a temp directory.

        The directory holds copies of all testdata files whose names are
        listed in configs. The KUBECONFIG environment variable is set to
        contain the configs listed in env_variable_configs (regardless of
        whether they exist). The default path is set to default_config.

        :param configs: A list of filenames to copy into the temp directory
        :type configs: list

        :param passed_config: A filename to be passed to --kubeconfig
        :type passed_config: string or None

        :param env_variable_configs: A list of filenames to put in KUBECONFIG
        :type env_variable_configs: list

        :param default_config: A config to be the default path
        :type default_config: string

        :param verbose: Whether to pass --verbose to the command
        :type verbose: bool
        """
        self.initialize_tempfiles(configs)
        env_variable = self.build_temp_environment_variable(
            env_variable_configs
        )
        args = ["--name", "ExampleCluster"]
        if passed_config is not None:
            args += ["--kubeconfig", self._get_temp_config(passed_config)]
        if verbose:
            args += ["--verbose"]

        self._run_update_kubeconfig(
            args, env_variable, self._get_temp_config(default_config)
        )

        self.mock_create_client.assert_called_once_with('eks')
        self.client.describe_cluster.assert_called_once_with(
            name='ExampleCluster'
        )

    def test_dry_run_new(self):
        passed = "new_config"
        environment = []

        captured_output = self.assert_cmd_dry(passed, environment)
        self.assert_output(captured_output, 'output_single')

    def test_dry_run_existing(self):
        passed = "valid_existing"
        environment = []

        captured_output = self.assert_cmd_dry(passed, environment)
        self.assert_output(captured_output, 'output_combined')

    def test_dry_run_empty(self):
        passed = "valid_empty_config"
        environment = []

        captured_output = self.assert_cmd_dry(passed, environment)
        self.assert_output(captured_output, 'output_single')

    def test_dry_run_corrupted(self):
        passed = "invalid_string_clusters"
        environment = []

        with self.assertRaises(KubeconfigCorruptedError):
            self.assert_cmd_dry(passed, environment)

    def test_write_new(self):
        configs = []
        passed = "new_config"
        environment = []

        self.assert_cmd(configs, passed, environment)
        self.assert_config_state("new_config", "output_single")

    def test_use_environment(self):
        configs = ['invalid_string_clusters',
                   'valid_empty_existing',
                   'valid_existing']
        passed = None
        environment = ['does_not_exist',
                       'invalid_string_clusters',
                       'valid_empty_existing',
                       'valid_existing']

        self.assert_cmd(configs, passed, environment)
        self.assert_config_state("does_not_exist", "output_single")

    def test_use_default(self):
        configs = ["valid_existing"]
        passed = None
        environment = []
        default = "valid_existing"

        self.assert_cmd(configs, passed, environment, default, verbose=True)
        self.assert_config_state("valid_existing", "output_combined")

    def test_all_corrupted(self):
        configs = ["invalid_string_cluster_entry",
                   "invalid_string_contexts",
                   "invalid_text"]
        passed = None
        environment = ["invalid_string_cluster_entry",
                       "invalid_string_contexts",
                       "invalid_text"]

        with self.assertRaises(KubeconfigCorruptedError):
            self.assert_cmd(configs, passed, environment)

    def test_all_but_one_corrupted(self):
        configs = ["valid_existing",
                   "invalid_string_cluster_entry",
                   "invalid_string_contexts",
                   "invalid_text"]
        passed = None
        environment = ["valid_existing",
                       "invalid_string_cluster_entry",
                       "invalid_string_contexts",
                       "invalid_text"]

        self.assert_cmd(configs, passed, environment)
        self.assert_config_state("valid_existing", 'output_combined')

    def test_corrupted_and_missing(self):
        configs = ["invalid_string_clusters",
                   "invalid_string_users"]
        passed = None
        environment = ["invalid_string_clusters",
                       "does_not_exist",
                       "does_not_exist2",
                       "invalid_string_users"]

        with self.assertRaises(KubeconfigCorruptedError):
            self.assert_cmd(configs, passed, environment)

    def test_one_corrupted_environment(self):
        configs = ["invalid_string_clusters"]
        passed = None
        environment = ["invalid_string_clusters"]

        with self.assertRaises(KubeconfigCorruptedError):
            self.assert_cmd(configs, passed, environment)

    def test_environmemt_empty_elements(self):
        configs = ["valid_existing"]

        self.initialize_tempfiles(configs)
        # A leading empty KUBECONFIG entry should be skipped.
        env_variable = build_environment([
            "",
            self._get_temp_config("valid_existing")
        ])
        args = ["--name", "ExampleCluster"]

        self._run_update_kubeconfig(
            args, env_variable, self._get_temp_config("default_temp")
        )

        self.mock_create_client.assert_called_once_with('eks')
        self.client.describe_cluster.assert_called_once_with(
            name='ExampleCluster'
        )
        self.assert_config_state("valid_existing", "output_combined")

    def test_environmemt_all_empty(self):
        configs = ["valid_existing"]

        self.initialize_tempfiles(configs)
        # Only blank entries: the command should fall back to DEFAULT_PATH.
        env_variable = build_environment(["", "", " ", "\t", ""])
        args = ["--name", "ExampleCluster"]

        self._run_update_kubeconfig(
            args, env_variable, self._get_temp_config("default_temp")
        )

        self.mock_create_client.assert_called_once_with('eks')
        self.client.describe_cluster.assert_called_once_with(
            name='ExampleCluster'
        )
        self.assert_config_state("default_temp", "output_single")

    def test_default_path_directory(self):
        configs = []
        passed = None
        environment = []
        # Default will be the temp directory once _get_temp_config is called
        default = ""

        with self.assertRaises(KubeconfigInaccessableError):
            self.assert_cmd(configs, passed, environment, default)

    def test_update_existing(self):
        configs = ["valid_old_data"]
        passed = "valid_old_data"
        environment = []

        self.assert_cmd(configs, passed, environment)
        self.assert_config_state("valid_old_data", "output_combined")

    def test_update_existing_environment(self):
        configs = ["valid_old_data"]
        passed = None
        environment = ["valid_old_data",
                       "output_combined",
                       "output_single"]

        self.assert_cmd(configs, passed, environment)
        self.assert_config_state("valid_old_data", "output_combined")

    def test_cluster_creating(self):
        configs = ["output_combined"]
        passed = "output_combined"
        environment = []
        self.client.describe_cluster = mock.Mock(
            return_value=describe_cluster_creating_response()
        )
        with self.assertRaises(EKSClusterError):
            self.assert_cmd(configs, passed, environment)

    def test_kubeconfig_order(self):
        configs = ["valid_changed_ordering"]
        passed = "valid_changed_ordering"
        environment = []

        self.assert_cmd(configs, passed, environment)
        self.assert_config_state("valid_changed_ordering",
                                 "output_combined_changed_ordering")

    def test_update_old_api_version(self):
        configs = ["valid_old_api_version"]
        passed = "valid_old_api_version"
        environment = []

        self.assert_cmd(configs, passed, environment)
        self.assert_config_state("valid_old_api_version",
                                 "valid_old_api_version_updated")

    def test_assume_role(self):
        """
        Test that assume_role_arn is handled correctly when provided.
        """
        configs = ["valid_existing"]
        self.initialize_tempfiles(configs)

        # Include the --assume-role-arn argument
        args = [
            "--name", "ExampleCluster",
            "--assume-role-arn", "arn:aws:iam::123456789012:role/test-role"
        ]

        kubeconfig_path = self._get_temp_config("valid_existing")
        default_path = self._get_temp_config("default_temp")

        self._run_update_kubeconfig(args, kubeconfig_path, default_path)

        # Verify that assume_role was called with the correct parameters
        self.sts_client_mock.assume_role.assert_called_once_with(
            RoleArn="arn:aws:iam::123456789012:role/test-role",
            RoleSessionName="EKSDescribeClusterSession"
        )

        # Verify that the EKS client was created with the assumed credentials
        self.mock_create_client.assert_any_call(
            "eks",
            aws_access_key_id="test-access-key",
            aws_secret_access_key="test-secret-key",
            aws_session_token="test-session-token"
        )

        # Verify that the cluster was described
        self.client.describe_cluster.assert_called_once_with(
            name="ExampleCluster"
        )

        # Assert the configuration state
        self.assert_config_state("valid_existing", "output_combined")

    def test_no_assume_role(self):
        """
        Test that assume_role_arn is not used when not provided.
        """
        configs = ["valid_existing"]
        passed = "valid_existing"
        environment = []

        self.assert_cmd(configs, passed, environment)

        # Verify that assume_role was not called and the EKS client was
        # created without assumed-role credentials.
        self.sts_client_mock.assume_role.assert_not_called()
        self.mock_create_client.assert_called_once_with("eks")
        self.client.describe_cluster.assert_called_once_with(
            name="ExampleCluster"
        )
